Add Full TE Spec support for Megatron Pruning DynamicModules + MoE bug fixes#1024

Merged
kevalmorabia97 merged 9 commits into main from kmorabia/minitron-full-te-spec
Mar 20, 2026
Conversation

@kevalmorabia97
Collaborator

@kevalmorabia97 kevalmorabia97 commented Mar 11, 2026

What does this PR do?

Type of change: Improvement + Bug Fix

Quantization recently added support for the full TE spec. This PR adds the same for pruning so we can retire the ModelOpt spec and just use the standard TE spec.
NOTE: We still don't support TEGroupedGemm and instead use TE SequentialMLP for now (this can be configured in the standard TE spec, so the ModelOpt spec is not needed).
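To make the SequentialMLP-vs-GroupedMLP point concrete, here is an illustrative sketch (not the actual ModelOpt or Megatron code; all names are hypothetical stand-ins) of how a spec factory can expose the MLP choice through a single flag, so the standard TE spec covers the SequentialMLP case without a custom spec:

```python
# Hypothetical sketch of a layer-spec factory. In the real Megatron TE spec the
# analogous switch is the moe_grouped_gemm argument.
def get_te_layer_spec(num_experts=None, moe_grouped_gemm=True):
    """Return a minimal description of which MLP submodule to build."""
    if num_experts is None:
        mlp = "TEMLP"          # dense (non-MoE) layer
    elif moe_grouped_gemm:
        mlp = "TEGroupedMLP"   # grouped-GEMM path (not yet supported for pruning)
    else:
        mlp = "SequentialMLP"  # the path this PR uses for MoE pruning
    return {"num_experts": num_experts, "mlp": mlp}

spec = get_te_layer_spec(num_experts=128, moe_grouped_gemm=False)
```

With `moe_grouped_gemm=False`, the factory selects `SequentialMLP`, which is why no separate ModelOpt spec is required.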

Note that this does not change how the pruning workflow is used, but it makes pruning slightly faster and may produce a slightly different pruned model because of different kernels and numerics.

[Bug fix]: Previously, NAS-based pruning for MoE models would hang when evaluating MMLU for pruned candidate models. Fixed in this PR as well.

[Bug fix]: Previously, hidden-size importance hooks were not applied to pre_mlp_layernorm for MoE layers. Also fixed in this PR, resulting in a significant MMLU improvement for Qwen3-30B-A3B.
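A toy sketch of why the missing hook matters (hypothetical names and values; the real hooks run on Megatron layernorm modules): hidden-size importance aggregates per-channel activation magnitudes, so every layernorm that feeds the hidden dimension, including pre_mlp_layernorm in MoE layers, must contribute, or the channel ranking is biased toward attention inputs only.

```python
# Accumulate per-channel importance from each hooked layernorm output.
def accumulate(scores, activations):
    for ch, value in enumerate(activations):
        scores[ch] += abs(value)

hidden_size = 4
scores = [0.0] * hidden_size

# Activations from the two layernorms of one MoE layer (made-up numbers):
input_layernorm_out = [0.1, -2.0, 0.3, 0.05]
pre_mlp_layernorm_out = [1.5, 0.2, -0.4, 0.02]  # previously skipped (the bug)

for acts in (input_layernorm_out, pre_mlp_layernorm_out):
    accumulate(scores, acts)

# Channel ranking now reflects both attention and MLP inputs.
ranking = sorted(range(hidden_size), key=lambda c: -scores[c])
```

Without the second `accumulate` call, channel 0 would look unimportant even though the MLP input depends on it heavily.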

Testing

  • Unit tests updated and passing
  • Compare pruning results for Qwen3-8B -> 6B. ⚠️ MMLU scores differ, resulting in a different best-picked model, but scores are in a similar range - the difference may be due to different kernels for TE layers
Least important 6 layers:
    ModelOpt Spec: 27, 28, 29, 31, 32, 33
    TE Spec: 27, 28, 30, 31, 32, 33

Top 10 pruned candidates:
| num_layers | hidden_size | ffn_hidden_size | Params (B) | MMLU (ModelOpt Spec) | MMLU (TE Spec) |
|------------|-------------|-----------------|------------|----------------------|----------------|
| 34         | 3328        | 11264           | 5.99       | 0.390                | 0.397          |
| 30         | 3584        | 11776           | 5.99       | 0.572 [BEST]         | 0.575          |
| 36         | 3840        | 8192            | 5.98       | 0.511                | 0.511          |
| 36         | 3584        | 9216            | 5.98       | 0.477                | 0.497          |
| 36         | 3072        | 11776           | 5.97       | 0.278                | 0.252          |
| 32         | 3584        | 10752           | 5.96       | 0.542                | 0.541          |
| 36         | 3328        | 10240           | 5.92       | 0.365                | 0.412          |
| 34         | 3840        | 8704            | 5.91       | 0.537                | 0.539          |
| 30         | 4096        | 9216            | 5.90       | 0.566                | 0.591 [BEST]   |
| 34         | 3584        | 9728            | 5.89       | 0.499                | 0.510          |
  • Compare pruning results for Nemotron-Nano-9B-v2 -> 7B. MMLU scores differ slightly but the best pruned model selection is the same
Least important 8 layers (Before and After): [43, 44, 45, 46, 47, 48, 50, 52]

Top 10 pruned candidates:
| num_layers | hidden_size | mamba_num_heads | mamba_head_dim | ffn_hidden_size | Params (B) | MMLU (ModelOpt Spec) | MMLU (TE Spec) |
|------------|-------------|------------------|---------------|-----------------|------------|----------------------|----------------|
| 50         | 4480        | 128              | 56            | 15680           | 7.00       | 0.211                | 0.202          |
| 56         | 4096        | 96               | 80            | 14336           | 7.00       | 0.438                | 0.436          |
| 48         | 4352        | 120              | 80            | 13824           | 7.00       | 0.679 [BEST]         | 0.679 [BEST]   |
| 56         | 4352        | 112              | 80            | 10240           | 7.00       | 0.516                | 0.520          |
| 54         | 4480        | 104              | 80            | 11264           | 7.00       | 0.263                | 0.262          |
| 46         | 4480        | 128              | 72            | 14848           | 7.00       | 0.610                | 0.617          |
| 50         | 4480        | 112              | 64            | 15680           | 7.00       | 0.426                | 0.421          |
| 54         | 4096        | 112              | 80            | 13312           | 7.00       | 0.579                | 0.589          |
| 56         | 4352        | 120              | 72            | 10752           | 7.00       | 0.466                | 0.469          |
| 52         | 4352        | 120              | 72            | 12800           | 7.00       | 0.561                | 0.560          |
  • Compare pruning results for Qwen3-30B-A3B -> 24B. Previously a bug in the added hooks hurt scores, so we now see a big improvement
Top 10 pruned candidates (~1 hour per candidate for MMLU computation, so skipped after the first 3):
| num_layers | hidden_size | num_attention_heads | num_moe_experts | Params (B)| MMLU (ModelOpt Spec) | MMLU (TE Spec) |
|------------|-------------|---------------------|-----------------|-----------|----------------------|----------------|
| 46         | 2048        | 28                  | 104             | 23.98B    | 0.663                | 0.698          |
| 40         | 2048        | 28                  | 120             | 23.95B    | 0.577                | 0.668          |
| 46         | 1792        | 24                  | 120             | 23.94B    | 0.435                | 0.500          |
| 46         | 2048        | 24                  | 104             | 23.88B    |                      |                |
| 40         | 2048        | 24                  | 120             | 23.87B    |                      |                |
| 46         | 1792        | 20                  | 120             | 23.85B    |                      |                |
| 40         | 2048        | 20                  | 120             | 23.78B    |                      |                |
| 46         | 2048        | 20                  | 104             | 23.78B    |                      |                |
| 42         | 2048        | 32                  | 112             | 23.62B    |                      |                |
| 48         | 1792        | 32                  | 112             | 23.54B    |                      |                |
  • Run pruning experiments for gptoss-20b (actually 21B) -> 18B with the TE spec. GPT-OSS MMLU seems to drop steeply even at 15% pruning
"Only considering atmost 40% for width and 20% for depth pruning hparams
Skipping hparams_to_skip=['num_attention_heads'] during search space generation...
        Search space for num_layers: [20, 22, 24]
        Search space for hidden_size: [2048, 2304, 2560, 2816, 2880]
        Search space for num_moe_experts: [24, 32]
        Search space for moe_ffn_hidden_size: [2048, 2304, 2560, 2816, 2880]
        Total search space in consideration: 150

Top 10 candidates with scores:
        {'num_layers': 20, 'hidden_size': 2880, 'num_moe_experts': 32, 'moe_ffn_hidden_size': 2880} -> 17.62B params, 0.3780 score
        {'num_layers': 22, 'hidden_size': 2880, 'num_moe_experts': 32, 'moe_ffn_hidden_size': 2560} -> 17.32B params, 0.4160 score [BEST SUBNET]
        {'num_layers': 20, 'hidden_size': 2880, 'num_moe_experts': 32, 'moe_ffn_hidden_size': 2816} -> 17.27B params, 0.3523 score
        {'num_layers': 20, 'hidden_size': 2816, 'num_moe_experts': 32, 'moe_ffn_hidden_size': 2880} -> 17.23B params, 0.3848 score
        {'num_layers': 22, 'hidden_size': 2560, 'num_moe_experts': 32, 'moe_ffn_hidden_size': 2880} -> 17.13B params, 0.3062 score
        {'num_layers': 24, 'hidden_size': 2880, 'num_moe_experts': 32, 'moe_ffn_hidden_size': 2304} -> 17.09B params, 0.3984 score
        {'num_layers': 22, 'hidden_size': 2816, 'num_moe_experts': 32, 'moe_ffn_hidden_size': 2560} -> 16.94B params, 0.3957 score
        {'num_layers': 20, 'hidden_size': 2816, 'num_moe_experts': 32, 'moe_ffn_hidden_size': 2816} -> 16.88B params, 0.3835 score
        {'num_layers': 22, 'hidden_size': 2560, 'num_moe_experts': 32, 'moe_ffn_hidden_size': 2816} -> 16.78B params, 0.2154 score
        {'num_layers': 24, 'hidden_size': 2304, 'num_moe_experts': 32, 'moe_ffn_hidden_size': 2880} -> 16.73B params, 0.0014 score"
  • Run pruning experiments for Nemotron-3-Nano-30B-A3B (actually 31.5B) -> 24B with the TE spec
Top 10 candidates with scores:
        {'num_layers': 46, 'hidden_size': 2688, 'mamba_num_heads': 64, 'num_moe_experts': 96, 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3712} -> 24.00B params, 0.0000 score
        {'num_layers': 52, 'hidden_size': 2048, 'mamba_num_heads': 64, 'num_moe_experts': 128, 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} -> 24.00B params, 0.2764 score
        {'num_layers': 48, 'hidden_size': 2688, 'mamba_num_heads': 56, 'num_moe_experts': 96, 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3712} -> 24.00B params, 0.6098 score
        {'num_layers': 52, 'hidden_size': 2560, 'mamba_num_heads': 64, 'num_moe_experts': 104, 'moe_ffn_hidden_size': 1792, 'moe_shared_expert_intermediate_size': 3328} -> 24.00B params, 0.6233 score
        {'num_layers': 48, 'hidden_size': 2688, 'mamba_num_heads': 64, 'num_moe_experts': 96, 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} -> 24.00B params, 0.6301 score [BEST SUBNET]
        {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 56, 'num_moe_experts': 96, 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 2560} -> 23.99B params, 0.6125 score
        {'num_layers': 52, 'hidden_size': 2688, 'mamba_num_heads': 48, 'num_moe_experts': 96, 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3328} -> 23.99B params, 0.4255 score
        {'num_layers': 50, 'hidden_size': 2048, 'mamba_num_heads': 64, 'num_moe_experts': 128, 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3584} -> 23.99B params, 0.2859 score
        {'num_layers': 50, 'hidden_size': 2688, 'mamba_num_heads': 56, 'num_moe_experts': 96, 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3072} -> 23.99B params, 0.6125 score
        {'num_layers': 42, 'hidden_size': 2304, 'mamba_num_heads': 40, 'num_moe_experts': 120, 'moe_ffn_hidden_size': 1856, 'moe_shared_expert_intermediate_size': 3584} -> 23.99B params, 0.0366 score

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ⚠️ TE uses different kernels, so the pruned model may differ slightly because of different numerics
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅

Additional Information

OMNIML-3504

Summary by CodeRabbit

  • New Features

    • Full Transformer Engine support for Minitron pruning; no custom model spec required.
  • Bug Fixes

    • Resolved pruning hang on MoE models by correcting importance-hook behavior.
  • Documentation

    • Updated changelog and example README; bumped recommended container tag and expanded Docker run/mount guidance; adjusted release dates.
  • Improvements

    • Per-rank local activation handling, broader candidate caching, runtime router scoring mitigation, clearer pruning status messages.
  • Tests

    • Updated tests to validate Transformer Engine backend and adjusted dynamic-module expectations.

@copy-pr-bot

copy-pr-bot bot commented Mar 11, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 11, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds Transformer Engine (TE) support across Megatron NAS and Minitron pruning: consolidates TE dynamic modules into the Megatron plugin, extracts fused TELayerNorm activations per rank, switches spec factories to TE factories, updates pruning searcher to local per-rank activations and candidate caching, and adapts examples and tests to TE paths.

Changes

| Cohort / File(s) | Summary |
|------------------|---------|
| Changelog & Examples: `CHANGELOG.rst`, `examples/megatron_bridge/README.md`, `examples/megatron_bridge/prune_minitron.py` | Bumped changelog to 0.44 with TE Minitron feature and MoE pruning bugfix; updated NeMo Docker tag and docker run snippet; clarified pruning example log messages. |
| Megatron NAS plugin consolidation: `modelopt/torch/nas/plugins/megatron.py`, `modelopt/torch/nas/plugins/transformer_engine.py` | Moved TE-aware dynamic-module logic into `megatron.py`, removed the standalone `transformer_engine.py`, added the `get_te_mamba_stack_spec(...)` public API, and switched conversions to TE module equivalents. |
| Minitron pruning (mcore_minitron): `modelopt/torch/prune/plugins/mcore_minitron.py` | Added hooks to extract fused TELayerNorm outputs; changed activation handling from PP-allgathered lists to per-rank `local_activations` dict APIs and checkpoint keys; store `all_candidates_per_constraint`; reinitialize all MoE dispatchers and toggle router expert bias during evaluation; unpatch TE modules on cleanup. |
| Bridge / spec selection: `modelopt/torch/utils/plugins/mbridge.py`, `tests/_test_utils/torch/megatron/models.py` | Replaced static spec usage with factory calls (`get_te_mamba_stack_spec`, `get_gpt_layer_with_transformer_engine_spec`), force `moe_grouped_gemm=False` for TE paths, thread `transformer_impl`/`moe_grouped_gemm` into test utils, and simplify tokenizer `use_fast` handling. |
| Tests: dynamics & pruning: `tests/gpu_megatron/.../test_megatron_*_dynamic_modules.py`, `tests/gpu_megatron/.../test_mcore_*_minitron_pruning.py`, `tests/_test_utils/torch/megatron/models.py` | Updated tests to construct TE transformer paths, assert TE-specific dynamic classes, set bf16 flags where needed, adapt pruning tests to `local_activations` and new candidate/state shapes, and adjust pruning score key checks. |

Sequence Diagram(s)

sequenceDiagram
  participant Searcher as MCoreMinitronSearcher
  participant Model as Megatron/TE Model
  participant Hook as ActivationHook
  participant Checkpoint as Per-rank Checkpoint
  participant Pruner as PruneRoutine

  rect rgba(100,149,237,0.5)
    Searcher->>Model: register ActivationHook on TELayerNormColumnParallelLinear
    Model-->>Hook: fused layernorm+linear forward outputs
    Hook->>Searcher: collect per-module activations (local_activations)
  end

  rect rgba(60,179,113,0.5)
    Searcher->>Checkpoint: set_local_activations_and_layer_scores(local_activations, layer_scores)
    Checkpoint-->>Searcher: saved per-rank activations
    Checkpoint->>Searcher: load local_activations for run_search
  end

  rect rgba(255,140,0,0.5)
    Searcher->>Pruner: invoke _prune using collected scores
    Pruner->>Model: apply pruning masks (no early break)
    Pruner->>Model: reinitialize token dispatcher if needed
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • AAnoosheh
  • realAsma
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|------------|--------|-------------|------------|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.79%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|------------|--------|-------------|
| Title check | ✅ Passed | The title accurately summarizes the main objectives: adding full TE spec support for Megatron pruning DynamicModules and fixing MoE-related bugs. It is concise, specific, and clearly conveys the primary changes. |
| Security Anti-Patterns | ✅ Passed | PR does not introduce security anti-patterns. `trust_remote_code` properly defaults to False with `is_safe_repo()` validation; no `torch.load`, eval/exec, or dependency violations found. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@kevalmorabia97 kevalmorabia97 changed the title from "Add Full TE Spec support for Megatron Pruning DynamicModules" to "Add Full TE Spec and GroupedMLP support for Megatron Pruning DynamicModules" Mar 11, 2026
@codecov

codecov bot commented Mar 11, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.29%. Comparing base (839fa3d) to head (da2f5dc).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1024      +/-   ##
==========================================
- Coverage   70.30%   70.29%   -0.02%     
==========================================
  Files         227      227              
  Lines       25857    25857              
==========================================
- Hits        18179    18176       -3     
- Misses       7678     7681       +3     

☔ View full report in Codecov by Sentry.

@kevalmorabia97 kevalmorabia97 changed the title from "Add Full TE Spec and GroupedMLP support for Megatron Pruning DynamicModules" to "Add Full TE Spec support for Megatron Pruning DynamicModules" Mar 11, 2026
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/minitron-full-te-spec branch 6 times, most recently from 8f42e0f to cff7137 Compare March 12, 2026 12:03
@kevalmorabia97 kevalmorabia97 changed the title from "Add Full TE Spec support for Megatron Pruning DynamicModules" to "Add Full TE Spec support for Megatron Pruning DynamicModules + MoE bug fixes" Mar 12, 2026
@kevalmorabia97 kevalmorabia97 marked this pull request as ready for review March 12, 2026 20:26
@kevalmorabia97 kevalmorabia97 requested review from a team as code owners March 12, 2026 20:26
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/prune/plugins/mcore_minitron.py (1)

918-935: Add checkpoint key validation before restoring _activations.

m._activations = local_activations[n] will raise a raw KeyError on checkpoint/model drift. A pre-check with a clear error message will make resume failures diagnosable.

Proposed hardening
     def set_local_activations_and_layer_scores(
@@
         print_rank_0("Loading activations and scores from per-rank checkpoint...")
         for layer in self.model.decoder.layers:
             layer._scores = layer_scores[layer.layer_number]
+        expected_keys = [
+            n for n, m in self.model.named_modules() if hasattr(m, "_activations")
+        ]
+        missing = [k for k in expected_keys if k not in local_activations]
+        if missing:
+            raise KeyError(
+                f"Missing activation entries for modules: {missing[:8]}"
+                + (" ..." if len(missing) > 8 else "")
+            )
         for n, m in self.model.named_modules():
             if hasattr(m, "_activations"):
                 m._activations = local_activations[n]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 918 - 935, The
current set_local_activations_and_layer_scores method assigns m._activations =
local_activations[n] without validating the key, which will raise a raw KeyError
on checkpoint/model drift; update set_local_activations_and_layer_scores to
check if n is in local_activations before assignment (iterate over
self.model.named_modules()), and if missing either raise a clear ValueError that
includes the module name n and a summary of available keys (e.g.,
list(local_activations.keys())) or log a descriptive warning and skip restoring
that module, so failures are diagnosable; reference the method name
set_local_activations_and_layer_scores and attributes _activations,
local_activations, and model.named_modules() when making the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 870-874: You’re overwriting TE modules’ return_layernorm_output
state unconditionally; instead capture each TELayerNormColumnParallelLinear’s
original return_layernorm_output before changing it (e.g., store in a dict keyed
by id(module) or attach a private attribute like _orig_return_layernorm_output
on the module) when you set it to True, and in the cleanup loop restore each
module’s original value rather than forcing False; apply the same
save-and-restore pattern for the other similar block referenced (lines
~999-1006) so original behavior is preserved after pruning/search.

---

Nitpick comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 918-935: The current set_local_activations_and_layer_scores method
assigns m._activations = local_activations[n] without validating the key, which
will raise a raw KeyError on checkpoint/model drift; update
set_local_activations_and_layer_scores to check if n is in local_activations
before assignment (iterate over self.model.named_modules()), and if missing
either raise a clear ValueError that includes the module name n and a summary of
available keys (e.g., list(local_activations.keys())) or log a descriptive
warning and skip restoring that module, so failures are diagnosable; reference
the method name set_local_activations_and_layer_scores and attributes
_activations, local_activations, and model.named_modules() when making the
change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7ec3047d-20d8-40ab-b3d9-6a92aa5ec6c0

📥 Commits

Reviewing files that changed from the base of the PR and between 7b198d6 and 8225445.

📒 Files selected for processing (2)
  • modelopt/torch/prune/plugins/mcore_minitron.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Collaborator

@ChenhanYu ChenhanYu left a comment


Review: Add Full TE Spec support for Megatron Pruning DynamicModules + MoE bug fixes

Summary

This PR migrates Megatron pruning from custom ModelOpt spec to standard Transformer Engine (TE) spec, enabling retirement of the ModelOpt spec. Key changes:

  • New TE-specific DynamicModule classes (_DynamicTEParallelLinear, _DynamicTEColumnParallelLinear, etc.) that use TE's in_features/out_features naming
  • Fused LayerNorm handling via return_layernorm_output=True patching on TELayerNormColumnParallelLinear
  • Activation collection simplified from allgather across PP ranks to per-rank local storage
  • Two MoE bug fixes: (1) token dispatcher reinit was only applied to the first MoE layer (break removal), (2) pre_mlp_layernorm hooks were missing for MoE layers
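MoE bug fix (1) can be shown in miniature (hypothetical class and attribute names): a loop that `break`s after the first matching layer only reinitializes the first MoE layer's token dispatcher, while removing the `break` handles every MoE layer.

```python
# Toy model: a list of layers, some of which are MoE layers.
class Layer:
    def __init__(self, is_moe):
        self.is_moe = is_moe
        self.dispatcher_reinitialized = False

layers = [Layer(False), Layer(True), Layer(True), Layer(True)]

# Buggy version: stops after the first MoE layer it finds.
for layer in layers:
    if layer.is_moe:
        layer.dispatcher_reinitialized = True
        break
buggy = [l.dispatcher_reinitialized for l in layers]

# Fixed version: no break, so every MoE layer's dispatcher is reinitialized.
for layer in layers:
    layer.dispatcher_reinitialized = False
for layer in layers:
    if layer.is_moe:
        layer.dispatcher_reinitialized = True
fixed = [l.dispatcher_reinitialized for l in layers]
```

The stale dispatchers on the untouched layers are what caused the evaluation hang described in the PR.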

Overall well-structured with thorough testing.

Attention Required

  1. TE is now a hard requirement for pruning. The HAS_TE conditional was removed and DotProductAttention fallback dropped — assert isinstance(self.core_attention, TEDotProductAttention) will crash without TE. This is an intentional design shift (ModelOpt spec retirement), but worth confirming no internal users still run without TE.

  2. Activation storage format changed (checkpoint-breaking). activations_per_rank (list) → local_activations (dict). Existing pruning checkpoints saved with the old format will fail to load. The PR description doesn't mention a migration path.
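Since no migration path is documented, here is a hedged sketch of what a one-off conversion from the old layout (`activations_per_rank`: a PP-allgathered list of per-rank dicts) to the new per-rank `local_activations` dict could look like. The key names are assumptions based on this review discussion, not a documented ModelOpt API.

```python
# Convert an old-format pruning checkpoint state dict in place (illustrative only).
def migrate_activations(state, pp_rank):
    if "local_activations" in state:
        return state  # already in the new per-rank format
    old = state.pop("activations_per_rank", None)
    if old is not None:
        # Keep only this pipeline rank's slice; other ranks migrate their own.
        state["local_activations"] = old[pp_rank]
    return state

old_state = {
    "activations_per_rank": [
        {"decoder.layers.0.mlp": [0.5]},
        {"decoder.layers.1.mlp": [0.7]},
    ]
}
new_state = migrate_activations(old_state, pp_rank=1)
```

In practice it may be simpler to just re-run the scoring pass, since the old checkpoint only saves one forward sweep's worth of statistics.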

bf16=bf16,
# MoE-specific parameters
moe_router_dtype=None,
moe_grouped_gemm=moe_grouped_gemm,
Collaborator


Changed from "fp32" to None. Is this intentional or an unrelated cleanup? Running MoE routers without explicit fp32 dtype can affect numerical stability.

@ChenhanYu ChenhanYu self-requested a review March 18, 2026 21:03
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/_test_utils/torch/megatron/models.py (1)

308-309: Reject unknown transformer_impl values in the Mamba helper.

Unlike get_mcore_gpt_model, this helper treats every non-"transformer_engine" value as "modelopt". A typo or a caller passing "local" will silently stop exercising the intended path.

♻️ Proposed fix
 def get_mcore_mamba_hybrid_model(
@@
     sequence_parallel: bool = False,
     transformer_impl: str = "modelopt",
@@
 ) -> MambaModel:
@@
     """
     assert HAS_MAMBA, "Mamba not installed"
+    assert transformer_impl in ["modelopt", "transformer_engine"]
@@
     if transformer_impl == "transformer_engine":
         mamba_spec = get_te_mamba_stack_spec(moe_grouped_gemm=moe_grouped_gemm)
-    else:
+    else:  # transformer_impl == "modelopt"
         mamba_spec = get_mamba_stack_modelopt_spec(remap_te_layernorm=True)

Also applies to: 390-393

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/_test_utils/torch/megatron/models.py` around lines 308 - 309, The Mamba
helper currently treats any transformer_impl value other than
"transformer_engine" as "modelopt", allowing silent typos; update the helper(s)
that read the transformer_impl parameter (around the transformer_impl =
"modelopt" default and the code at the other occurrence) to perform explicit
validation: accept only the supported strings ("modelopt" and
"transformer_engine") and raise a ValueError with a clear message if an unknown
value is passed, rather than silently defaulting—adjust both occurrences (lines
near transformer_impl and the block at 390-393) to enforce this check.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py`:
- Around line 74-75: The tests set transformer_impl="transformer_engine" but
lack a guard, so add a Transformer Engine availability skip: either call
pytest.importorskip("megatron.core.extensions.transformer_engine") at module
scope in the test module (alongside the existing skip_if_no_mamba() call) or
implement a helper skip_if_no_transformer_engine() in
_test_utils/import_helper.py and invoke it next to skip_if_no_mamba(); ensure
the check runs before any model construction that uses transformer_impl to avoid
runtime failures.

---

Nitpick comments:
In `@tests/_test_utils/torch/megatron/models.py`:
- Around line 308-309: The Mamba helper currently treats any transformer_impl
value other than "transformer_engine" as "modelopt", allowing silent typos;
update the helper(s) that read the transformer_impl parameter (around the
transformer_impl = "modelopt" default and the code at the other occurrence) to
perform explicit validation: accept only the supported strings ("modelopt" and
"transformer_engine") and raise a ValueError with a clear message if an unknown
value is passed, rather than silently defaulting—adjust both occurrences (lines
near transformer_impl and the block at 390-393) to enforce this check.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d8683652-f5b5-467a-9641-54568a74e6f6

📥 Commits

Reviewing files that changed from the base of the PR and between 8225445 and 4b54afb.

📒 Files selected for processing (6)
  • CHANGELOG.rst
  • examples/megatron_bridge/prune_minitron.py
  • modelopt/torch/nas/plugins/megatron.py
  • modelopt/torch/prune/plugins/mcore_minitron.py
  • tests/_test_utils/torch/megatron/models.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • examples/megatron_bridge/prune_minitron.py
  • CHANGELOG.rst
  • modelopt/torch/nas/plugins/megatron.py
  • modelopt/torch/prune/plugins/mcore_minitron.py

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/minitron-full-te-spec branch 2 times, most recently from cf71476 to 067e80a Compare March 19, 2026 17:36
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/prune/plugins/mcore_minitron.py (1)

477-513: ⚠️ Potential issue | 🟡 Minor

Make candidate validation exception-safe.

This block temporarily disables enable_expert_bias and mutates the model into each candidate subnet, but restoration only happens on the happy path. If _prune() or eval_score() raises, the process keeps the patched router flags and partially mutated model state. Wrap the validation section in a try/finally and restore the router flags / max subnet there.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 477 - 513, The
validation loop temporarily disables router flags and mutates the model but only
restores them on the happy path; make it exception-safe by wrapping the
candidate validation loop (the logic that iterates top_k_candidates and calls
_prune, eval_score, sample, and saves/restores model.decoder.layers) in a
try/finally so that in the finally you: (1) re-enable every module in
_routers_with_expert_bias by setting m.enable_expert_bias = True, and (2)
restore the model to the max subnet and original layer numbering (use
sample(self.model, sample_func=max) and reset layer.layer_number from the saved
start_layer_number and reassign self.model.decoder.layers) to ensure cleanup
runs even if _prune() or eval_score() throws.
♻️ Duplicate comments (2)
modelopt/torch/prune/plugins/mcore_minitron.py (1)

277-292: ⚠️ Potential issue | 🟠 Major

Restore TE return_layernorm_output to the original value.

The registration path flips every fused TE module to True, but cleanup() always forces False. Any module that started as True will change behavior after pruning/search. Please store the original flag per module and restore that exact value; run_search() should also call cleanup() from a finally block so exceptions don’t leak the patch.

Proposed fix
 class ImportanceEstimatorRegistry:
     def __init__(self, model: DynamicModule):
         """Initialize the registry."""
         assert isinstance(model, _DynamicMCoreLanguageModel), "Model must be a DynamicModule"
         self.model = model
         self._hooks: list[tuple[nn.Module, Any]] = []  # List of (module, hook_handle) tuples
+        self._te_ln_linear_prev_flags: dict[nn.Module, bool] = {}
@@
     def cleanup(self) -> None:
         """Remove all registered hooks and temporary attributes."""
         # Remove all hooks
         for _, handle in self._hooks:
             handle.remove()
         self._hooks.clear()
 
-        # Unpatch return_layernorm_output on fused TELayerNormColumnParallelLinear modules
-        for m in self.model.modules():
-            if isinstance(m, TELayerNormColumnParallelLinear):
-                m.return_layernorm_output = False
+        for m, prev_flag in self._te_ln_linear_prev_flags.items():
+            m.return_layernorm_output = prev_flag
+        self._te_ln_linear_prev_flags.clear()
@@
     for m in module.modules():
         if isinstance(m, TELayerNormColumnParallelLinear):
+            if m not in registry._te_ln_linear_prev_flags:
+                registry._te_ln_linear_prev_flags[m] = m.return_layernorm_output
             m.return_layernorm_output = True

Also applies to: 881-885, 1014-1016

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 277 - 292, The
cleanup flow currently forces every fused TE module's return_layernorm_output to
False and the registration path sets them to True, which alters module behavior
after pruning; update the code (in registration/patching logic that flips fused
TE modules and in cleanup()) to record each module's original
return_layernorm_output value (e.g., store a per-module map keyed by module
identity) and restore that exact original boolean in cleanup() instead of
hardcoding False, and ensure run_search() invokes cleanup() inside a finally
block so the original flags are restored even if an exception occurs; refer to
functions/methods cleanup(), run_search(), and the return_layernorm_output flag
on the fused TE modules when making the changes.
tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (1)

74-75: ⚠️ Potential issue | 🟡 Minor

Guard these TE-only tests before model/plugin import.

These cases now force transformer_impl="transformer_engine", but the module still only calls skip_if_no_mamba(). In environments with Mamba installed and Transformer Engine absent, this file will fail instead of skipping cleanly. Add pytest.importorskip("megatron.core.extensions.transformer_engine") or a shared helper next to skip_if_no_mamba().

Also applies to: 155-156, 274-275

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py`
around lines 74 - 75, The tests set transformer_impl="transformer_engine" but
only call skip_if_no_mamba(), which lets cases run when Mamba exists but
Transformer Engine does not; add a guard that imports or skips the TE module
before importing or instantiating the model/plugin (e.g., call
pytest.importorskip("megatron.core.extensions.transformer_engine") or create/use
a shared helper that does that) and apply the same change for the other
occurrences where transformer_impl="transformer_engine" is used; keep the
existing skip_if_no_mamba() but ensure the TE import-or-skip runs earlier to
prevent import failures.
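A shared guard helper along the lines the prompt proposes might look like this. The helper name `skip_if_no_transformer_engine` and the module path are taken from the review text; this is a sketch modeled on the existing `skip_if_no_mamba()` pattern, not the repository's actual helper.

```python
import pytest


def skip_if_no_transformer_engine():
    """Skip the calling tests cleanly when Transformer Engine is unavailable."""
    # importorskip raises pytest's Skipped exception if the import fails,
    # so tests bail out as "skipped" rather than erroring on import.
    pytest.importorskip("megatron.core.extensions.transformer_engine")
```

Calling it at module level, next to `skip_if_no_mamba()`, ensures the skip fires before any model/plugin import is attempted.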
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Line 37: Wrap the module-level import of TELayerNormColumnParallelLinear in a
try/except that sets TELayerNormColumnParallelLinear = None and a HAS_TE boolean
flag (True on success, False on ImportError), then conditionally execute any
Transformer-Engine-specific logic only when HAS_TE is True; specifically guard
every place that references TELayerNormColumnParallelLinear (e.g., isinstance
checks and TE hook registrations) and the TE-specific hook registration blocks
so they run only if HAS_TE is True.
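The optional-import gating described above can be sketched like this. The import path matches Megatron-LM's TE extension module; `fused_te_modules` is a hypothetical stand-in for the plugin's hook-registration logic, shown only to illustrate guarding by the flag.

```python
try:
    from megatron.core.extensions.transformer_engine import (
        TELayerNormColumnParallelLinear,
    )

    HAS_TE = True
except ImportError:
    # Importing this plugin should not fail just because TE is absent.
    TELayerNormColumnParallelLinear = None
    HAS_TE = False


def fused_te_modules(model):
    """Return fused TE layernorm-linear modules, or nothing when TE is absent."""
    if not HAS_TE:
        return []
    return [
        m for m in model.modules() if isinstance(m, TELayerNormColumnParallelLinear)
    ]
```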

In `@modelopt/torch/utils/plugins/mbridge.py`:
- Around line 95-103: The else branch calls provider.num_moe_experts and
provider.qk_layernorm without defensive checks; update the branch around
transformer_layer_spec assignment (get_gpt_layer_with_transformer_engine_spec)
to either guard these attribute accesses with hasattr(provider,
"num_moe_experts") and hasattr(provider, "qk_layernorm") and handle missing
attributes (fallback values or raise a clear error), or add a concise
comment/docstring near the else branch documenting the assumption that provider
is a GPTModelProvider with those attributes (and cite the external megatron
source). Ensure references to provider.num_moe_experts, provider.qk_layernorm,
and get_gpt_layer_with_transformer_engine_spec are updated accordingly so future
callers see the validation or documented contract.

In `@tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py`:
- Around line 75-76: The test hardcodes transformer_impl="transformer_engine"
but only calls skip_if_no_mamba(), so when Mamba exists and TE does not the test
will fail; add a new helper skip_if_no_transformer_engine() to
tests/_test_utils/import_helper.py (modeled on the TE guards used in
test_megatron.py around line ~964) and call skip_if_no_transformer_engine() at
the module level in
tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
alongside the existing skip_if_no_mamba() so the test is cleanly skipped when
Transformer Engine is unavailable.

---

Outside diff comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 477-513: The validation loop temporarily disables router flags and
mutates the model but only restores them on the happy path; make it
exception-safe by wrapping the candidate validation loop (the logic that
iterates top_k_candidates and calls _prune, eval_score, sample, and
saves/restores model.decoder.layers) in a try/finally so that in the finally
you: (1) re-enable every module in _routers_with_expert_bias by setting
m.enable_expert_bias = True, and (2) restore the model to the max subnet and
original layer numbering (use sample(self.model, sample_func=max) and reset
layer.layer_number from the saved start_layer_number and reassign
self.model.decoder.layers) to ensure cleanup runs even if _prune() or
eval_score() throws.

---

Duplicate comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 277-292: The cleanup flow currently forces every fused TE module's
return_layernorm_output to False and the registration path sets them to True,
which alters module behavior after pruning; update the code (in
registration/patching logic that flips fused TE modules and in cleanup()) to
record each module's original return_layernorm_output value (e.g., store a
per-module map keyed by module identity) and restore that exact original boolean
in cleanup() instead of hardcoding False, and ensure run_search() invokes
cleanup() inside a finally block so the original flags are restored even if an
exception occurs; refer to functions/methods cleanup(), run_search(), and the
return_layernorm_output flag on the fused TE modules when making the changes.

In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py`:
- Around line 74-75: The tests set transformer_impl="transformer_engine" but
only call skip_if_no_mamba(), which lets cases run when Mamba exists but
Transformer Engine does not; add a guard that imports or skips the TE module
before importing or instantiating the model/plugin (e.g., call
pytest.importorskip("megatron.core.extensions.transformer_engine") or create/use
a shared helper that does that) and apply the same change for the other
occurrences where transformer_impl="transformer_engine" is used; keep the
existing skip_if_no_mamba() but ensure the TE import-or-skip runs earlier to
prevent import failures.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bb96a786-c906-4909-a05e-e62d8eccbf69

📥 Commits

Reviewing files that changed from the base of the PR and between 4b54afb and cf71476.

📒 Files selected for processing (12)
  • CHANGELOG.rst
  • examples/megatron_bridge/README.md
  • examples/megatron_bridge/prune_minitron.py
  • modelopt/torch/nas/plugins/megatron.py
  • modelopt/torch/nas/plugins/transformer_engine.py
  • modelopt/torch/prune/plugins/mcore_minitron.py
  • modelopt/torch/utils/plugins/mbridge.py
  • tests/_test_utils/torch/megatron/models.py
  • tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py
  • tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/nas/plugins/transformer_engine.py
✅ Files skipped from review due to trivial changes (2)
  • modelopt/torch/nas/plugins/megatron.py
  • tests/_test_utils/torch/megatron/models.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • examples/megatron_bridge/prune_minitron.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
  • tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py
  • examples/megatron_bridge/README.md
  • CHANGELOG.rst

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/prune/plugins/mcore_minitron.py (1)

476-513: ⚠️ Potential issue | 🟠 Major

Restore enable_expert_bias in a finally block.

If eval_score() or checkpoint saving raises inside this loop, the model is left with expert bias permanently disabled for the rest of the search.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 476 - 513, The
loop temporarily disables module expert bias by setting m.enable_expert_bias =
False and collects them in _routers_with_expert_bias, but if eval_score() or
save_search_checkpoint() throws the routers remain disabled; wrap the validation
phase (the for candidate in tqdm(...) loop and its inner pruning/eval/save
logic) in a try/finally (or surround the whole top-k validation block) and in
the finally re-enable each router by setting m.enable_expert_bias = True (using
the existing _routers_with_expert_bias list) so expert bias is always restored
even on exceptions.
♻️ Duplicate comments (2)
modelopt/torch/prune/plugins/mcore_minitron.py (2)

37-37: ⚠️ Potential issue | 🟠 Major

Gate the TE-specific import like the rest of the optional stack.

This makes the pruning plugin fail import whenever Transformer Engine is absent, even if the caller never exercises the TE path. Please wrap the import in try/except ImportError and guard the TE-specific hook logic with a feature flag.

As per coding guidelines, modelopt/**/*.py: Avoid hard imports of optional dependencies at module level; gate features by install extras ([onnx], [hf], [all]).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/plugins/mcore_minitron.py` at line 37, The module
currently hard-imports TELayerNormColumnParallelLinear which causes import
failure when Transformer Engine (TE) is absent; wrap the import in a try/except
ImportError and set a module-level feature flag (e.g., _HAS_TE = True/False)
accordingly, then guard all TE-specific logic/hook usage (references to
TELayerNormColumnParallelLinear and any TE-only hooks in the plugin class in
mcore_minitron.py) behind that flag so importing the module succeeds even when
TE is not installed.

881-885: ⚠️ Potential issue | 🟠 Major

Restore each module’s original return_layernorm_output value.

This still patches every fused TE module to True and later forces False globally. Any module that started with True is silently mutated after cleanup, and the blanket patch is broader than the modules you actually hook.

Also applies to: 1014-1016

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 881 - 885, The
cleanup loop currently forces return_layernorm_output=False on every
TELayerNormColumnParallelLinear in self.model, mutating modules that may have
originally been True; instead track which specific modules you patched (e.g.,
store them in a list when you set return_layernorm_output=True where you install
hooks) and in the cleanup only restore each tracked module's original value
(store the original value per-module when patching). Update the code that sets
return_layernorm_output to record (module, original_value) and replace the
current iteration over self.model.modules() with iteration over that recorded
list to restore original_value on the same TELayerNormColumnParallelLinear
instances you modified.
🧹 Nitpick comments (1)
modelopt/torch/prune/plugins/mcore_minitron.py (1)

186-190: The checkpoint payload types no longer match the data you store.

local_activations is not uniformly Tensor here: the hidden-size collector stores a nested dict keyed by module id, the SequentialMLP collector stores a dict of tensors, and layer_scores are Python floats (.item()), not tensors. Tighten these aliases and signatures so mypy and checkpoint consumers see the real payload shape.

As per coding guidelines, **/*.py: Use mypy for type checking on Python code (configured in pyproject.toml).

Also applies to: 208-214, 914-946

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 186 - 190, The
checkpoint payload type hints are incorrect: adjust the annotations for
local_activations and layer_scores (and any related aliases near sorted_layers
and all_candidates_per_constraint) to match what is actually stored — make
local_activations a union type reflecting both nested dicts (dict[str, dict[str,
torch.Tensor]]) and flat dicts (dict[str, torch.Tensor]) produced by the
hidden-size collector and SequentialMLP collector, and change layer_scores from
dict[int, torch.Tensor] to dict[int, float] (since .item() is stored); ensure
sorted_layers remains Optional[List[int]] (1-indexed) and
all_candidates_per_constraint stays dict[float, list[CandidateSubnet]]; update
any type aliases or function signatures that construct or consume these payloads
so mypy and checkpoint serializers see the real shapes (search for
local_activations, layer_scores, sorted_layers, all_candidates_per_constraint,
CandidateSubnet to locate all usages).
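Tightened aliases matching the stored payloads might read as follows. `Tensor` is a plain stand-in here so the sketch stays framework-free; the alias names mirror the review prompt rather than the file's actual definitions.

```python
from typing import Optional, Union

Tensor = object  # stand-in for torch.Tensor in this sketch

LocalActivations = Union[
    dict[str, dict[str, Tensor]],  # hidden-size collector: nested per module id
    dict[str, Tensor],  # SequentialMLP collector: flat tensors
]
LayerScores = dict[int, float]  # .item() stores Python floats, not tensors
SortedLayers = Optional[list[int]]  # 1-indexed layer numbers
```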
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/nas/plugins/megatron.py`:
- Around line 390-403: The specialized TE parallel linear class registers only
"out_features" (via self._register_dynamic_attribute("out_features", ...)) but
not "in_features", causing stale dimension metadata after slicing; update the
class to also call self._register_dynamic_attribute("in_features", <callable>)
mirroring the parent _DynamicTEParallelLinear behavior (compute input dim from
mod.config.kv_channels and any active/group counts as appropriate), and ensure
"in_features" and "out_features" are kept in sync with the sliced weight/bias
(use the same pattern/closures used for "out_features" and for
bias/_get_weight/_get_bias/_get_ln_param to compute current dimensions).
- Around line 25-31: The module currently imports transformer_engine and the
Megatron TE extension classes (transformer_engine, TEColumnParallelLinear,
TEDotProductAttention, TELayerNormColumnParallelLinear, TERowParallelLinear)
unconditionally which breaks imports when TE is absent; wrap these imports in a
try/except and set a HAS_TE boolean (mirroring the HAS_MAMBA pattern already in
this file) so downstream registration of TE-specific features only occurs when
HAS_TE is True, and ensure any code that references those TE symbols is guarded
by the HAS_TE check similar to the pattern used in
modelopt/torch/quantization/plugins/megatron.py.

---

Outside diff comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 476-513: The loop temporarily disables module expert bias by
setting m.enable_expert_bias = False and collects them in
_routers_with_expert_bias, but if eval_score() or save_search_checkpoint()
throws the routers remain disabled; wrap the validation phase (the for candidate
in tqdm(...) loop and its inner pruning/eval/save logic) in a try/finally (or
surround the whole top-k validation block) and in the finally re-enable each
router by setting m.enable_expert_bias = True (using the existing
_routers_with_expert_bias list) so expert bias is always restored even on
exceptions.

---

Duplicate comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Line 37: The module currently hard-imports TELayerNormColumnParallelLinear
which causes import failure when Transformer Engine (TE) is absent; wrap the
import in a try/except ImportError and set a module-level feature flag (e.g.,
_HAS_TE = True/False) accordingly, then guard all TE-specific logic/hook usage
(references to TELayerNormColumnParallelLinear and any TE-only hooks in the
plugin class in mcore_minitron.py) behind that flag so importing the module
succeeds even when TE is not installed.
- Around line 881-885: The cleanup loop currently forces
return_layernorm_output=False on every TELayerNormColumnParallelLinear in
self.model, mutating modules that may have originally been True; instead track
which specific modules you patched (e.g., store them in a list when you set
return_layernorm_output=True where you install hooks) and in the cleanup only
restore each tracked module's original value (store the original value
per-module when patching). Update the code that sets return_layernorm_output to
record (module, original_value) and replace the current iteration over
self.model.modules() with iteration over that recorded list to restore
original_value on the same TELayerNormColumnParallelLinear instances you
modified.

---

Nitpick comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 186-190: The checkpoint payload type hints are incorrect: adjust
the annotations for local_activations and layer_scores (and any related aliases
near sorted_layers and all_candidates_per_constraint) to match what is actually
stored — make local_activations a union type reflecting both nested dicts
(dict[str, dict[str, torch.Tensor]]) and flat dicts (dict[str, torch.Tensor])
produced by the hidden-size collector and SequentialMLP collector, and change
layer_scores from dict[int, torch.Tensor] to dict[int, float] (since .item() is
stored); ensure sorted_layers remains Optional[List[int]] (1-indexed) and
all_candidates_per_constraint stays dict[float, list[CandidateSubnet]]; update
any type aliases or function signatures that construct or consume these payloads
so mypy and checkpoint serializers see the real shapes (search for
local_activations, layer_scores, sorted_layers, all_candidates_per_constraint,
CandidateSubnet to locate all usages).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 8d87f835-546c-4fe1-995b-e21625d0dee5

📥 Commits

Reviewing files that changed from the base of the PR and between cf71476 and 067e80a.

📒 Files selected for processing (6)
  • CHANGELOG.rst
  • examples/megatron_bridge/prune_minitron.py
  • modelopt/torch/nas/plugins/megatron.py
  • modelopt/torch/prune/plugins/mcore_minitron.py
  • tests/_test_utils/torch/megatron/models.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
✅ Files skipped from review due to trivial changes (1)
  • tests/_test_utils/torch/megatron/models.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • examples/megatron_bridge/prune_minitron.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
  • CHANGELOG.rst

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/minitron-full-te-spec branch from 067e80a to 74184b8 Compare March 19, 2026 19:43
     # MoE-specific parameters
-    moe_router_dtype=None,
     moe_grouped_gemm=moe_grouped_gemm,
+    moe_router_dtype="fp32",
Contributor

MoE router dtype FP32 is needed in Nemotron models

Collaborator Author

@kevalmorabia97 kevalmorabia97 Mar 19, 2026

I was seeing some issue with the pruning test model, so I didn't set it here. For Nemotron3 Nano pruning on the actual model, I used its default and it works fine.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 merged commit f22f4f5 into main Mar 20, 2026
42 checks passed
@kevalmorabia97 kevalmorabia97 deleted the kmorabia/minitron-full-te-spec branch March 20, 2026 10:56